{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1789c00b-6928-4167-b677-93a4a0dbd08c",
   "metadata": {},
   "source": [
    "### 加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15417e78-d2f4-4d0b-b2c0-07890c080485",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Location of the base Llama-2 checkpoint (edit to match your setup).\n",
    "model_name = \"/path/llama-2-7b-hf\"\n",
    "\n",
    "# Base model loaded with float16 weights; device placement is left to device_map=\"auto\".\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "# Tokenizer that ships with the base checkpoint.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Extended (merged) tokenizer; point this at the directory where you saved it.\n",
    "new_tokenizer = AutoTokenizer.from_pretrained(\"/path/to/merged_tokenizer_hf\")\n",
    "\n",
    "# Show the module tree as the cell output.\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c76fb61-efc3-49d5-a7cb-32f9e610f0b7",
   "metadata": {},
   "source": [
    "### 随机扩充"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b52abe0c-0e22-4b38-9496-8c4860047d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the current input embedding and the vector stored for token id 31999.\n",
    "embeddings = model.get_input_embeddings()\n",
    "print(embeddings)\n",
    "# Build the probe index on the same device as the weights: a CPU index raises a\n",
    "# device-mismatch error when the embedding lives on an accelerator.\n",
    "probe_ids = torch.LongTensor([31999]).to(embeddings.weight.device)\n",
    "print(embeddings(probe_ids))\n",
    "\n",
    "# Grow the table to the merged tokenizer's size instead of a hard-coded 40114, so it\n",
    "# always covers every id that new_tokenizer can produce.\n",
    "model.resize_token_embeddings(len(new_tokenizer))\n",
    "new_embeddings = model.get_input_embeddings()\n",
    "print(new_embeddings)\n",
    "# Same probe after resizing, for comparison with the vector printed above.\n",
    "print(new_embeddings(probe_ids.to(new_embeddings.weight.device)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8013cb8b-3634-441f-840b-65bcd91fa5c1",
   "metadata": {},
   "source": [
    "### 均值扩充"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1930339-2cc5-4ef6-80f7-0b0d5912ae8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map every newly added token id to the original-tokenizer ids that spell it.\n",
    "original_vocab_size = 32000  # ids below this already exist in the base vocabulary\n",
    "prefix_piece_id = 29871  # per the original note, the id of the leading '_' piece -- TODO confirm\n",
    "token_mapping = {}\n",
    "for i in range(original_vocab_size, len(new_tokenizer)):\n",
    "    # Turn the new id back into its token string.\n",
    "    token = new_tokenizer.convert_ids_to_tokens(i)\n",
    "    # Re-encode that string with the original tokenizer.\n",
    "    input_ids = tokenizer(token, return_tensors=\"pt\").input_ids[0]\n",
    "    # Position 0 is always dropped (presumably the BOS id -- confirm); a prefix piece at\n",
    "    # position 1 is dropped as well. The length check avoids an IndexError when the\n",
    "    # encoding holds a single id.\n",
    "    if len(input_ids) > 1 and input_ids[1] == prefix_piece_id:\n",
    "        new_input_ids = input_ids[2:]\n",
    "    else:\n",
    "        new_input_ids = input_ids[1:]\n",
    "    token_mapping[i] = new_input_ids\n",
    "\n",
    "# Current input embedding and its weight tensor.\n",
    "embeddings = model.get_input_embeddings()\n",
    "old_weight = embeddings.weight.data\n",
    "\n",
    "# Fresh embedding sized for the merged vocabulary. The width and padding index are read\n",
    "# from the existing layer instead of hard-coding 4096.\n",
    "new_vocab_size = len(new_tokenizer)\n",
    "embedding_dim = embeddings.embedding_dim\n",
    "new_embedding = torch.nn.Embedding(\n",
    "    new_vocab_size, embedding_dim, padding_idx=embeddings.padding_idx\n",
    ")\n",
    "\n",
    "# Copy every row the old table already has (capped at the new size).\n",
    "num_to_copy = min(new_vocab_size, len(old_weight))\n",
    "new_embedding.weight.data[:num_to_copy, :] = old_weight[:num_to_copy, :]\n",
    "\n",
    "# Each new token starts as the mean of the embeddings of its original pieces.\n",
    "for new_token, original_tokens in token_mapping.items():\n",
    "    if len(original_tokens) == 0:\n",
    "        # Nothing to average (the mean would be NaN); keep the default initialisation.\n",
    "        continue\n",
    "    # Index the weights on their own device and average in float32 for precision.\n",
    "    piece_rows = old_weight[original_tokens.to(old_weight.device)]\n",
    "    new_embedding.weight.data[new_token] = piece_rows.float().mean(dim=0)\n",
    "\n",
    "# The new layer was created with torch's default dtype and device; match the existing\n",
    "# weights before installing it so the model does not end up with mixed dtypes/devices.\n",
    "new_embedding = new_embedding.to(device=old_weight.device, dtype=old_weight.dtype)\n",
    "model.set_input_embeddings(new_embedding)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "562725e4-abaa-4d89-a7e0-9473e9fa208c",
   "metadata": {},
   "source": [
    "#### 扩充lm_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d65e0e2e-91ff-4a73-bf8e-8775266f015c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild lm_head for the merged vocabulary, initialising each new row from the mean of\n",
    "# the rows of its original pieces (token_mapping comes from the mean-expansion cell).\n",
    "output_size = 32000  # rows carried over unchanged from the existing head\n",
    "new_output_size = len(new_tokenizer)  # was hard-coded to 40114\n",
    "lm_head = model.lm_head\n",
    "old_head_weight = lm_head.weight.data\n",
    "\n",
    "# New projection; the input width is read from the existing head instead of hard-coding 4096.\n",
    "new_lm_head = torch.nn.Linear(\n",
    "    in_features=lm_head.in_features, out_features=new_output_size, bias=False\n",
    ")\n",
    "# The first output_size rows are copied as they are.\n",
    "new_lm_head.weight.data[:output_size, :] = old_head_weight[:output_size, :]\n",
    "\n",
    "# Fill the rows of the added tokens.\n",
    "for new_token, original_tokens in token_mapping.items():\n",
    "    if len(original_tokens) == 0:\n",
    "        # Nothing to average (the old code divided by zero here); keep the default init.\n",
    "        continue\n",
    "    # Gather the piece rows on the weight's device and average in float32.\n",
    "    piece_rows = old_head_weight[original_tokens.to(old_head_weight.device)]\n",
    "    new_lm_head.weight.data[new_token] = piece_rows.float().mean(dim=0)\n",
    "\n",
    "# The new layer was created with torch's default dtype and device; match the existing\n",
    "# head before swapping it in so the saved checkpoint has a single dtype.\n",
    "new_lm_head = new_lm_head.to(device=old_head_weight.device, dtype=old_head_weight.dtype)\n",
    "model.lm_head = new_lm_head\n",
    "\n",
    "# Keep the config in sync with the new layer sizes; otherwise the saved config.json can\n",
    "# still report the old vocab_size and reloading fails with a shape mismatch.\n",
    "model.config.vocab_size = new_output_size\n",
    "\n",
    "# Embedding and lm_head are both replaced at this point; write the model to disk.\n",
    "model.save_pretrained(\"llama-2-7b-extent\", max_shard_size=\"8GB\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
